{
 "cells": [
  {
   "cell_type": "raw",
   "id": "e515a630",
   "metadata": {},
   "source": [
    "---\n",
    "Copyright 2025 The JAX Authors.\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "    https://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25ac15cc",
   "metadata": {},
   "source": [
    "# Autodidax2, part 1: JAX from scratch, again"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6999020a",
   "metadata": {},
   "source": [
    "If you want to understand how JAX works you could trying reading the code. But\n",
    "the code is complicated, often for no good reason. This notebook presents a\n",
    "stripped-back version without the cruft. It's a minimal version of JAX from\n",
    "first principles. Enjoy!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62dde49f",
   "metadata": {},
   "source": [
    "## Main idea: context-sensitive interpretation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d13d5272",
   "metadata": {},
   "source": [
    "JAX is two things:\n",
    "  1. a set of primitive operations (roughly the NumPy API)\n",
    "  2. a set of interpreters over those primitives (compilation, AD, etc.)\n",
    "\n",
    "In this minimal version of JAX we'll start with just two primitive operations,\n",
    "addition and multiplication, and we'll add interpreters one by one. Suppose we\n",
    "have a user-defined function like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9f179429",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.809256Z",
     "iopub.status.busy": "2025-02-12T20:21:55.808149Z",
     "iopub.status.idle": "2025-02-12T20:21:55.827374Z",
     "shell.execute_reply": "2025-02-12T20:21:55.826143Z"
    }
   },
   "outputs": [],
   "source": [
    "def foo(x):\n",
    "  return mul(x, add(x, 3.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "809d67a6",
   "metadata": {},
   "source": [
    "We want to be able to interpret `foo` in different ways without changing its\n",
    "implementation: we want to evaluate it on concrete values, differentiate it,\n",
    "stage it out to an IR, compile it and so on."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4235f52",
   "metadata": {},
   "source": [
    "Here's how we'll do it. For each of these interpretations we'll define an\n",
    "`Interpreter` object with a rule for handling each primitive operation. We'll\n",
    "keep track of the *current* interpreter using a global context variable. The\n",
    "user-facing functions `add` and `mul` will dispatch to the current\n",
    "interpreter. At the beginning of the program the current interpreter will be\n",
    "the \"evaluating\" interpreter which just evaluates the operations on ordinary\n",
    "concrete data. Here's what this all looks like so far."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "65b26bdc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.830948Z",
     "iopub.status.busy": "2025-02-12T20:21:55.830603Z",
     "iopub.status.idle": "2025-02-12T20:21:55.842672Z",
     "shell.execute_reply": "2025-02-12T20:21:55.841832Z"
    }
   },
   "outputs": [],
   "source": [
    "from enum import Enum, auto\n",
    "from contextlib import contextmanager\n",
    "from typing import Any\n",
    "\n",
    "# The full (closed) set of primitive operations\n",
    "class Op(Enum):\n",
    "  add = auto()  # addition on floats\n",
    "  mul = auto()  # multiplication on floats\n",
    "\n",
    "# Interpreters have rules for handling each primitive operation.\n",
    "class Interpreter:\n",
    "  def interpret_op(self, op: Op, args: tuple[Any, ...]):\n",
    "    assert False, \"subclass should implement this\"\n",
    "\n",
    "# Our first interpreter is the \"evaluating interpreter\" which performs ordinary\n",
    "# concrete evaluation.\n",
    "class EvalInterpreter:\n",
    "  def interpret_op(self, op, args):\n",
    "    assert all(isinstance(arg, float) for arg in args)\n",
    "    match op:\n",
    "      case Op.add:\n",
    "        x, y = args\n",
    "        return x + y\n",
    "      case Op.mul:\n",
    "        x, y = args\n",
    "        return x * y\n",
    "      case _:\n",
    "        raise ValueError(f\"Unrecognized primitive op: {op}\")\n",
    "\n",
    "# The current interpreter is initially the evaluating interpreter.\n",
    "current_interpreter = EvalInterpreter()\n",
    "\n",
    "# A context manager for temporarily changing the current interpreter\n",
    "@contextmanager\n",
    "def set_interpreter(new_interpreter):\n",
    "  global current_interpreter\n",
    "  prev_interpreter = current_interpreter\n",
    "  try:\n",
    "    current_interpreter = new_interpreter\n",
    "    yield\n",
    "  finally:\n",
    "    current_interpreter = prev_interpreter\n",
    "\n",
    "# The user-facing functions `mul` and `add` dispatch to the current interpreter.\n",
    "def add(x, y): return current_interpreter.interpret_op(Op.add, (x, y))\n",
    "def mul(x, y): return current_interpreter.interpret_op(Op.mul, (x, y))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7d4ff3a",
   "metadata": {},
   "source": [
    "At this point we can call `foo` with ordinary concrete inputs and see the\n",
    "results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5aa8511c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.846387Z",
     "iopub.status.busy": "2025-02-12T20:21:55.846085Z",
     "iopub.status.idle": "2025-02-12T20:21:55.850202Z",
     "shell.execute_reply": "2025-02-12T20:21:55.849420Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.0\n"
     ]
    }
   ],
   "source": [
    "print(foo(2.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "175587ca",
   "metadata": {},
   "source": [
    "## Aside: forward-mode automatic differentiation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e003ba3f",
   "metadata": {},
   "source": [
    "For our second interpreter we're going to try forward-mode automatic\n",
    "differentiation (AD). Here's a quick introduction to forward-mode AD in case\n",
    "this is the first time you've come across it. Otherwise skip ahead to the\n",
    "\"JVPInterprer\" section."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bb1dde9",
   "metadata": {},
   "source": [
    "Suppose we're interested in the derivative of `foo(x)` evaluated at `x=2.0`.\n",
    "We could approximate it with finite differences:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9151fcd4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.852192Z",
     "iopub.status.busy": "2025-02-12T20:21:55.852015Z",
     "iopub.status.idle": "2025-02-12T20:21:55.855275Z",
     "shell.execute_reply": "2025-02-12T20:21:55.854676Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7.000009999913458\n"
     ]
    }
   ],
   "source": [
    "print((foo(2.00001) - foo(2.0)) / 0.00001)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c3ce8ae",
   "metadata": {},
   "source": [
    "The answer is close to 7.0 as expected. But computing it this way required two\n",
    "evaluations of the function (not to mention the roundoff error and truncation\n",
    "error). Here's a funny thing though. We can almost get the answer with a\n",
    "single evaluation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cba962a2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.857141Z",
     "iopub.status.busy": "2025-02-12T20:21:55.856974Z",
     "iopub.status.idle": "2025-02-12T20:21:55.859864Z",
     "shell.execute_reply": "2025-02-12T20:21:55.859432Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.0000700001\n"
     ]
    }
   ],
   "source": [
    "print(foo(2.00001))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a256c797",
   "metadata": {},
   "source": [
    "The answer we're looking for, 7.0, is right there in the insignificant digits!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "171a1aab",
   "metadata": {},
   "source": [
    "Here's one way to think about what's happening. The initial argument to `foo`,\n",
    "`2.00001`, carries two pieces of data: a \"primal\" value, 2.0, and a \"tangent\"\n",
    "value, `1.0`. The representation of this primal-tangent pair, `2.00001`, is\n",
    "the sum of the two, with the tangent scaled by a small fixed epsilon, `1e-5`.\n",
    "Ordinary evaluation of `foo(2.00001)` propagates this primal-tangent pair,\n",
    "producing `10.0000700001` as the result. The primal and tangent components are\n",
    "well separated in scale so we can visually interpret the result as the\n",
    "primal-tangent pair (10.0, 7.0), ignoring the the ~1e-10 truncation error at\n",
    "the end."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47420177",
   "metadata": {},
   "source": [
    "The idea with forward-mode differentiation is to do the same thing but exactly\n",
    "and explicitly (eyeballing floats doesn't really scale). We'll represent the\n",
    "primal-tangent pair as an actual pair instead of folding them both into a\n",
    "single floating point number. For each primitive operation we'll have a rule\n",
    "that describes how to propagate these primal tangent pairs. Let's work out the\n",
    "rules for our two primitives."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "309dc70d",
   "metadata": {},
   "source": [
    "Addition is easy. Consider `x + y` where `x = xp + xt * eps` and `y = yp + yt * eps`\n",
    "(\"p\" for \"primal\", \"t\" for \"tangent\"):\n",
    "\n",
    "     x + y = (xp + xt * eps) + (yp + yt * eps)\n",
    "           =   (xp + yp)             # primal component\n",
    "             + (xt + yt) * eps       # tangent component\n",
    "\n",
    "The result is a first-order polynomial in `eps` and we can read off the\n",
    "primal-tangent pair as (xp + yp, xt + yt)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59302b21",
   "metadata": {},
   "source": [
    "Multiplication is more interesting:\n",
    "\n",
    "     x * y = (xp + xt * eps) * (yp + yt * eps)\n",
    "           =    (xp * yp)                        # primal component\n",
    "              + (xp * yt + xt * yp) * eps        # tangent component\n",
    "              + (xt * yt)           * eps * eps  # quadratic component, vanishes in the eps->0 limit\n",
    "\n",
    "Now we have a second order polynomial. But as epsilon goes to zero the\n",
    "quadratic term vanishes and our primal-tangent pair\n",
    "is just `(xp * yp, xp * yt + xt * yp)`\n",
    "(In our earlier example with finite `eps` this term not vanishing is\n",
    "why we had the 1e-10 \"truncation error\")."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37fa5063",
   "metadata": {},
   "source": [
    "Putting this into code, we can write down the forward-AD rules for addition\n",
    "and multiplication and express `foo` in terms of these:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "57222038",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.862460Z",
     "iopub.status.busy": "2025-02-12T20:21:55.862160Z",
     "iopub.status.idle": "2025-02-12T20:21:55.868704Z",
     "shell.execute_reply": "2025-02-12T20:21:55.867858Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DualNumber(primal=10.0, tangent=7.0)\n"
     ]
    }
   ],
   "source": [
    "from dataclasses import dataclass\n",
    "\n",
    "# A primal-tangent pair is conventionally called a \"dual number\"\n",
    "@dataclass\n",
    "class DualNumber:\n",
    "  primal  : float\n",
    "  tangent : float\n",
    "\n",
    "def add_dual(x : DualNumber, y: DualNumber) -> DualNumber:\n",
    "  return DualNumber(x.primal + y.primal, x.tangent + y.tangent)\n",
    "\n",
    "def mul_dual(x : DualNumber, y: DualNumber) -> DualNumber:\n",
    "  return DualNumber(x.primal * y.primal, x.primal * y.tangent + x.tangent * y.primal)\n",
    "\n",
    "def foo_dual(x : DualNumber) -> DualNumber:\n",
    "  return mul_dual(x, add_dual(x, DualNumber(3.0, 0.0)))\n",
    "\n",
    "print (foo_dual(DualNumber(2.0, 1.0)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54947cc7",
   "metadata": {},
   "source": [
    "That works! But rewriting `foo` to use the `_dual` versions of addition and\n",
    "multiplication was a bit tedious. Let's get back to the main program and use\n",
    "our interpretation machinery to do the rewrite automatically."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25edb5f4",
   "metadata": {},
   "source": [
    "## JVP Interpreter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "945933c6",
   "metadata": {},
   "source": [
    "We'll set up a new interpreter called `JVPInterpreter` (\"JVP\" for\n",
    "\"Jacobian-vector product\") which propagates these dual numbers instead of\n",
    "ordinary values. The `JVPInterpreter` has methods 'add' and 'mul' that operate\n",
    "on dual number. They cast constant arguments to dual numbers as needed by\n",
    "calling `JVPInterpreter.lift`. In our manually rewritten version above we did\n",
    "that by replacing the literal `3.0` with `DualNumber(3.0, 0.0)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d17a4fb0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.871469Z",
     "iopub.status.busy": "2025-02-12T20:21:55.871220Z",
     "iopub.status.idle": "2025-02-12T20:21:55.880456Z",
     "shell.execute_reply": "2025-02-12T20:21:55.879794Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10.0, 7.0)\n"
     ]
    }
   ],
   "source": [
    "# This is like DualNumber above except that is also has a pointer to the\n",
    "# interpreter it belongs to, which is needed to avoid \"perturbation confusion\"\n",
    "# in higher order differentiation.\n",
    "@dataclass\n",
    "class TaggedDualNumber:\n",
    "  interpreter : Interpreter\n",
    "  primal  : float\n",
    "  tangent : float\n",
    "\n",
    "class JVPInterpreter(Interpreter):\n",
    "  def __init__(self, prev_interpreter: Interpreter):\n",
    "    # We keep a pointer to the interpreter that was current when this\n",
    "    # interpreter was first invoked. That's the context in which our\n",
    "    # rules should run.\n",
    "    self.prev_interpreter = prev_interpreter\n",
    "\n",
    "  def interpret_op(self, op, args):\n",
    "    args = tuple(self.lift(arg) for arg in args)\n",
    "    with set_interpreter(self.prev_interpreter):\n",
    "      match op:\n",
    "        case Op.add:\n",
    "          # Notice that we use `add` and `mul` here, which are the\n",
    "          # interpreter-dispatching functions defined earlier.\n",
    "          x, y = args\n",
    "          return self.dual_number(\n",
    "              add(x.primal, y.primal),\n",
    "              add(x.tangent, y.tangent))\n",
    "\n",
    "        case Op.mul:\n",
    "          x, y = args\n",
    "          x = self.lift(x)\n",
    "          y = self.lift(y)\n",
    "          return self.dual_number(\n",
    "              mul(x.primal, y.primal),\n",
    "              add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))\n",
    "\n",
    "  def dual_number(self, primal, tangent):\n",
    "    return TaggedDualNumber(self, primal, tangent)\n",
    "\n",
    "  # Lift a constant value (constant with respect to this interpreter) to\n",
    "  # a TaggedDualNumber.\n",
    "  def lift(self, x):\n",
    "    if isinstance(x, TaggedDualNumber) and x.interpreter is self:\n",
    "      return x\n",
    "    else:\n",
    "      return self.dual_number(x, 0.0)\n",
    "\n",
    "def jvp(f, primal, tangent):\n",
    "  jvp_interpreter = JVPInterpreter(current_interpreter)\n",
    "  dual_number_in = jvp_interpreter.dual_number(primal, tangent)\n",
    "  with set_interpreter(jvp_interpreter):\n",
    "    result = f(dual_number_in)\n",
    "  dual_number_out = jvp_interpreter.lift(result)\n",
    "  return dual_number_out.primal, dual_number_out.tangent\n",
    "\n",
    "# Let's try it out:\n",
    "print(jvp(foo, 2.0, 1.0))\n",
    "\n",
    "# Because we were careful to consider nesting interpreters, higher-order AD\n",
    "# works out of the box:\n",
    "\n",
    "def derivative(f, x):\n",
    "  _, tangent = jvp(f, x, 1.0)\n",
    "  return tangent\n",
    "\n",
    "def nth_order_derivative(n, f, x):\n",
    "  if n == 0:\n",
    "    return f(x)\n",
    "  else:\n",
    "    return derivative(lambda x: nth_order_derivative(n-1, f, x), x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3acc3839",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.882187Z",
     "iopub.status.busy": "2025-02-12T20:21:55.882009Z",
     "iopub.status.idle": "2025-02-12T20:21:55.885190Z",
     "shell.execute_reply": "2025-02-12T20:21:55.884635Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.0\n"
     ]
    }
   ],
   "source": [
    "print(nth_order_derivative(0, foo, 2.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "187eb028",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.886848Z",
     "iopub.status.busy": "2025-02-12T20:21:55.886685Z",
     "iopub.status.idle": "2025-02-12T20:21:55.889507Z",
     "shell.execute_reply": "2025-02-12T20:21:55.889081Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7.0\n"
     ]
    }
   ],
   "source": [
    "print(nth_order_derivative(1, foo, 2.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9f0dde6d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.891061Z",
     "iopub.status.busy": "2025-02-12T20:21:55.890896Z",
     "iopub.status.idle": "2025-02-12T20:21:55.894142Z",
     "shell.execute_reply": "2025-02-12T20:21:55.893701Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0\n"
     ]
    }
   ],
   "source": [
    "print(nth_order_derivative(2, foo, 2.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4d086fb3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.895591Z",
     "iopub.status.busy": "2025-02-12T20:21:55.895398Z",
     "iopub.status.idle": "2025-02-12T20:21:55.898277Z",
     "shell.execute_reply": "2025-02-12T20:21:55.897870Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n"
     ]
    }
   ],
   "source": [
    "# The rest are zero because `foo` is only a second-order polymonial\n",
    "print(nth_order_derivative(3, foo, 2.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e3164405",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.899736Z",
     "iopub.status.busy": "2025-02-12T20:21:55.899545Z",
     "iopub.status.idle": "2025-02-12T20:21:55.902719Z",
     "shell.execute_reply": "2025-02-12T20:21:55.902303Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n"
     ]
    }
   ],
   "source": [
    "print(nth_order_derivative(4, foo, 2.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e51ca61",
   "metadata": {},
   "source": [
    "There are some subtleties worth discussing. First, how do you tell if\n",
    "something is constant with respect to differentiation? It's tempting to say\n",
    "\"it's a constant if and only if it's not a dual number\". But actually dual\n",
    "numbers created by a *different* JVPInterpreter also need to be considered\n",
    "constants with respect to the JVPInterpreter we're currently handling. That's\n",
    "why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This\n",
    "comes up in higher order differentiation when there are multiple JVPInterprers\n",
    "in scope. The sort of bug where you accidentally interpret a dual number from\n",
    "a different interpreter as non-constant is sometimes called \"perturbation\n",
    "confusion\" in the literature. Here's an example program that would have given\n",
    "the wrong answer if we hadn't had the `and x.interpreter is self` check in\n",
    "`JVPInterpreter.lift`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ae1449a0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.904294Z",
     "iopub.status.busy": "2025-02-12T20:21:55.904105Z",
     "iopub.status.idle": "2025-02-12T20:21:55.907284Z",
     "shell.execute_reply": "2025-02-12T20:21:55.906874Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n"
     ]
    }
   ],
   "source": [
    "def f(x):\n",
    "  # g is constant in its (ignored) argument `y`. Its derivative should be zero\n",
    "  # but our AD will mess it up if we don't distinguish perturbations from\n",
    "  # different interpreters.\n",
    "  def g(y):\n",
    "    return x\n",
    "  should_be_zero = derivative(g, 0.0)\n",
    "  return mul(x, should_be_zero)\n",
    "\n",
    "print(derivative(f, 0.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "884d6f62",
   "metadata": {},
   "source": [
    "Another subtlety: `JVPInterpreter.add` and `JVPInterpreter.mul` describe\n",
    "addition and multiplication on dual numbers in terms of addition and\n",
    "multiplication on the primal and tangent components. But we don't use ordinary\n",
    "`+` and `*` for this. Instead we use our own `add` and `mul` functions which\n",
    "dispatch to the current interpreter. Before calling them we set the current\n",
    "interpreter to be the *previous* interpreter, i.e. the interpreter that was\n",
    "current when `JVPInterpreter` was first invoked. If we didn't do this we'd\n",
    "have an infinite recursion, with `add` and `mul` dispatching to\n",
    "`JVPInterpreter` endlessly. The advantage of using own `add` and `mul` instead\n",
    "of ordinary `+` and `*` is that it means we can nest these interpreters and do\n",
    "higher-order AD."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e03446b3",
   "metadata": {},
   "source": [
    "At this point you might be wondering: have we just reinvented operator\n",
    "overloading? Python overloads the infix ops `+` and `*` to dispatch to the\n",
    "argument's `__add__` and `__mul__`. Could we have just used that mechanism\n",
    "instead of this whole interpreter business? Yes, actually. Indeed, the earlier\n",
    "automatic differentiation (AD) literature uses the term \"operator overloading\"\n",
    "to describe this style of AD implementation. One detail is that we can't rely\n",
    "exclusively on Python built-in overloading because that only lets us overload\n",
    "a handful of built-in infix ops whereas we eventually want to overload\n",
    "numpy-level operations like `sin` and `cos`. So we need our own mechanism."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e2035ea",
   "metadata": {},
   "source": [
    "But there's a more important difference: our dispatch is based on *context*\n",
    "whereas traditional Python-style overloading is based on *data*. This is\n",
    "actually a recent development for JAX. The earliest versions of JAX looked\n",
    "more like traditional data-based overloading. An interpreter (a \"trace\" in JAX\n",
    "jargon) for an operation would be chosen based on data attached to the\n",
    "arguments to that operation. We've gradually made the interpreter-dispatch\n",
    "decision rely more and more on context rather than data (omnistaging [link],\n",
    "stackless [link]). The reason to prefer context-based interpretation over\n",
    "data-based interpretation is that it makes the implementation much simpler."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e5e55ee",
   "metadata": {},
   "source": [
    "All that said, we do *also* want to take advantage of Python's built-in\n",
    "overloading mechanism. That way we get the syntactic convenience of using\n",
    "infix operators `+` and `*` instead of writing out `add(..)` and `mul(..)`.\n",
    "But we'll put that aside for now."
   ]
  },
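  {
   "cell_type": "markdown",
   "id": "7b1e4c90",
   "metadata": {},
   "source": [
    "As a rough illustration of how the two mechanisms could coexist, here is a\n",
    "minimal, self-contained sketch in which Python's infix operators simply forward\n",
    "to context-dispatching `add` and `mul` functions. It is not the notebook's\n",
    "implementation: `Value`, `unwrap` and `SketchEval` are hypothetical names, ops are\n",
    "plain strings, and the only interpreter is a stand-in float evaluator.\n",
    "\n",
    "```python\n",
    "# Hypothetical stand-in for the evaluating interpreter: plain float arithmetic.\n",
    "class SketchEval:\n",
    "  def interpret_op(self, op, args):\n",
    "    x, y = args\n",
    "    return x + y if op == 'add' else x * y\n",
    "\n",
    "current_interpreter = SketchEval()\n",
    "\n",
    "# Context-based dispatch, in the same spirit as the notebook's `add` and `mul`.\n",
    "def add(x, y): return current_interpreter.interpret_op('add', (x, y))\n",
    "def mul(x, y): return current_interpreter.interpret_op('mul', (x, y))\n",
    "\n",
    "def unwrap(x):\n",
    "  # Strip the wrapper (if any) to get the value the interpreter works with.\n",
    "  return x.payload if isinstance(x, Value) else x\n",
    "\n",
    "class Value:\n",
    "  # Hypothetical wrapper whose only job is to provide infix syntax.\n",
    "  def __init__(self, payload):\n",
    "    self.payload = payload\n",
    "  # Each operator forwards to `add` or `mul`, so the interpreter is still\n",
    "  # chosen by context rather than by the type of the operands.\n",
    "  def __add__(self, other):  return Value(add(self.payload, unwrap(other)))\n",
    "  def __radd__(self, other): return Value(add(unwrap(other), self.payload))\n",
    "  def __mul__(self, other):  return Value(mul(self.payload, unwrap(other)))\n",
    "  def __rmul__(self, other): return Value(mul(unwrap(other), self.payload))\n",
    "\n",
    "def foo(x):\n",
    "  return x * (x + 3.0)\n",
    "\n",
    "result = foo(Value(2.0)).payload\n",
    "print(result)\n",
    "```\n",
    "\n",
    "The wrapper carries no interpreter of its own, so which rule runs is still\n",
    "decided by `current_interpreter` rather than by the operand's type."
   ]
  },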
  {
   "cell_type": "markdown",
   "id": "c9a22b61",
   "metadata": {},
   "source": [
    "# 3. Staging to an untyped IR"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "480e2328",
   "metadata": {},
   "source": [
    "The two program transformations we've seen so far -- evaluation and JVP --\n",
    "both traverse the input program from top to bottom. They visit the operations\n",
    "one by one in the same order as ordinary evaluation. A convenient thing about\n",
    "top-to-bottom transformations is that they can be implemented eagerly, or\n",
    "\"online\", meaning that we can evaluate the program from top to bottom and\n",
    "perform the necessary transformations as we go. We never look at the entire\n",
    "program at once."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4adad7b",
   "metadata": {},
   "source": [
    "But not all transformations work this way. For example, dead-code elimination\n",
    "requires traversing from bottom to top, collecting usage statistics on the way\n",
    "up and eliminating pure operations whose results have no uses. Another\n",
    "bottom-to-top transformation is AD transposition, which we use to implement\n",
    "reverse-mode AD. For these we need to first \"stage\" the program into an IR\n",
    "(internal representation), a data structure representing the program, which we\n",
    "can then traverse in any order we like. Building this IR from a Python program\n",
    "will be the goal of our third and final interpreter."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c8ba558",
   "metadata": {},
   "source": [
    "First, let's define the IR. We'll do an untypes ANF IR to start. A function\n",
    "(we call IR functions \"jaxprs\" in JAX) will have a list of formal parameters,\n",
    "a list of operations, and a return value. Each argument to an operation must\n",
    "be an \"atom\", which is either a variable or a literal. The return value of the\n",
    "function is also an atom."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2100d92e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.909146Z",
     "iopub.status.busy": "2025-02-12T20:21:55.908956Z",
     "iopub.status.idle": "2025-02-12T20:21:55.914391Z",
     "shell.execute_reply": "2025-02-12T20:21:55.913886Z"
    }
   },
   "outputs": [],
   "source": [
    "Var = str           # Variables are just strings in this untyped IR\n",
    "Atom = Var | float  # Atoms (arguments to operations) can be variables or (float) literals\n",
    "\n",
    "# Equation - a single line in our IR like `z = mul(x, y)`\n",
    "@dataclass\n",
    "class Equation:\n",
    "  var  : Var         # The variable name of the result\n",
    "  op   : Op          # The primitive operation we're applying\n",
    "  args : tuple[Atom] # The arguments we're applying the primitive operation to\n",
    "\n",
    "# We call an IR function a \"Jaxpr\", for \"JAX expression\"\n",
    "@dataclass\n",
    "class Jaxpr:\n",
    "  parameters : list[Var]      # The function's formal parameters (arguments)\n",
    "  equations  : list[Equation] # The body of the function, a list of instructions/equations\n",
    "  return_val : Atom           # The function's return value\n",
    "\n",
    "  def __str__(self):\n",
    "    lines = []\n",
    "    lines.append(', '.join(b for b in self.parameters) + ' ->')\n",
    "    for eqn in self.equations:\n",
    "      args_str = ', '.join(str(arg) for arg in eqn.args)\n",
    "      lines.append(f'  {eqn.var} = {eqn.op}({args_str})')\n",
    "    lines.append(self.return_val)\n",
    "    return '\\n'.join(lines)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36720d5c",
   "metadata": {},
   "source": [
    "To build the IR from a Python function we define a `StagingInterpreter` that\n",
    "takes each operation and adds it to a growing list of all the operations we've\n",
    "seen so far:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0ed04f2e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.916077Z",
     "iopub.status.busy": "2025-02-12T20:21:55.915878Z",
     "iopub.status.idle": "2025-02-12T20:21:55.920161Z",
     "shell.execute_reply": "2025-02-12T20:21:55.919753Z"
    }
   },
   "outputs": [],
   "source": [
    "class StagingInterpreter(Interpreter):\n",
    "  def __init__(self):\n",
    "    self.equations = []         # A mutable list of all the ops we've seen so far\n",
    "    self.name_counter = 0  # Counter for generating unique names\n",
    "\n",
    "  def fresh_var(self):\n",
    "    self.name_counter += 1\n",
    "    return \"v_\" + str(self.name_counter)\n",
    "\n",
    "  def interpret_op(self, op, args):\n",
    "    binder = self.fresh_var()\n",
    "    self.equations.append(Equation(binder, op, args))\n",
    "    return binder\n",
    "\n",
    "def build_jaxpr(f, num_args):\n",
    "  interpreter = StagingInterpreter()\n",
    "  parameters = tuple(interpreter.fresh_var() for _ in range(num_args))\n",
    "  with set_interpreter(interpreter):\n",
    "    result = f(*parameters)\n",
    "  return Jaxpr(parameters, interpreter.equations, result)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bde02c9",
   "metadata": {},
   "source": [
    "Now we can construct an IR for a Python program and print it out:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "606d2e23",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.921731Z",
     "iopub.status.busy": "2025-02-12T20:21:55.921538Z",
     "iopub.status.idle": "2025-02-12T20:21:55.924256Z",
     "shell.execute_reply": "2025-02-12T20:21:55.923850Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "v_1 ->\n",
      "  v_2 = Op.add(v_1, 3.0)\n",
      "  v_3 = Op.mul(v_1, v_2)\n",
      "v_3\n"
     ]
    }
   ],
   "source": [
    "print(build_jaxpr(foo, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67deabe8",
   "metadata": {},
   "source": [
    "We can also evaluate our IR by writing an explicit interpreter that traverses\n",
    "the operations one by one:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6a20cc84",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.925838Z",
     "iopub.status.busy": "2025-02-12T20:21:55.925646Z",
     "iopub.status.idle": "2025-02-12T20:21:55.929596Z",
     "shell.execute_reply": "2025-02-12T20:21:55.929187Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.0\n"
     ]
    }
   ],
   "source": [
    "def eval_jaxpr(jaxpr, args):\n",
    "  # An environment mapping variables to values\n",
    "  env = dict(zip(jaxpr.parameters, args))\n",
    "  def eval_atom(x): return env[x] if isinstance(x, Var) else x\n",
    "  for eqn in jaxpr.equations:\n",
    "    args = tuple(eval_atom(x) for x in eqn.args)\n",
    "    env[eqn.var] = current_interpreter.interpret_op(eqn.op, args)\n",
    "  return eval_atom(jaxpr.return_val)\n",
    "\n",
    "print(eval_jaxpr(build_jaxpr(foo, 1), (2.0,)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6250492",
   "metadata": {},
   "source": [
    "We've written this interpreter in terms of `current_interpreter.interpret_op`\n",
    "which means we've done a full round-trip: interpretable Python program to IR\n",
    "to interpretable Python program. Since the result is \"interpretable\" we can\n",
    "differentiate it again, or stage it out or anything we like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "831924b8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-02-12T20:21:55.931176Z",
     "iopub.status.busy": "2025-02-12T20:21:55.930983Z",
     "iopub.status.idle": "2025-02-12T20:21:55.933902Z",
     "shell.execute_reply": "2025-02-12T20:21:55.933490Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10.0, 7.0)\n"
     ]
    }
   ],
   "source": [
    "print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0))"
   ]
  },
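  {
   "cell_type": "markdown",
   "id": "c5d20e8a",
   "metadata": {},
   "source": [
    "Having the program as data also makes the bottom-to-top transformations\n",
    "mentioned earlier straightforward. As a minimal sketch, here is one way\n",
    "dead-code elimination might look on an equation list shaped like ours. The\n",
    "names `SketchEqn` and `eliminate_dead_code` are hypothetical and the block is\n",
    "self-contained, so it uses string op names instead of the `Op` enum.\n",
    "\n",
    "```python\n",
    "from dataclasses import dataclass\n",
    "\n",
    "@dataclass\n",
    "class SketchEqn:\n",
    "  var: str     # name of the result variable\n",
    "  op: str      # 'add' or 'mul'; both are pure, so unused results can be dropped\n",
    "  args: tuple  # atoms: variable names (str) or float literals\n",
    "\n",
    "def eliminate_dead_code(equations, return_val):\n",
    "  # Walk the equations from bottom to top, tracking which variables are used.\n",
    "  live = {return_val} if isinstance(return_val, str) else set()\n",
    "  kept = []\n",
    "  for eqn in reversed(equations):\n",
    "    if eqn.var in live:\n",
    "      kept.append(eqn)\n",
    "      live.update(arg for arg in eqn.args if isinstance(arg, str))\n",
    "  kept.reverse()  # restore top-to-bottom order\n",
    "  return kept\n",
    "\n",
    "equations = [\n",
    "  SketchEqn('v_2', 'add', ('v_1', 3.0)),\n",
    "  SketchEqn('v_3', 'mul', ('v_1', 'v_1')),  # result is never used\n",
    "  SketchEqn('v_4', 'mul', ('v_1', 'v_2')),\n",
    "]\n",
    "pruned = eliminate_dead_code(equations, 'v_4')\n",
    "print([eqn.var for eqn in pruned])\n",
    "```\n",
    "\n",
    "The equation defining `v_3` is dropped because nothing below it uses the result.\n",
    "An online interpreter could not make that decision when it first saw the operation."
   ]
  },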
  {
   "cell_type": "markdown",
   "id": "d3ac5873",
   "metadata": {},
   "source": [
    "## Up next..."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3451298",
   "metadata": {},
   "source": [
    "That's it for part one of this tutorial. We've done two primitives, three\n",
    "interpreters and the tracing mechanism that weaves them together. In the next\n",
    "part we'll add types other than floats, error handling, compilation,\n",
    "reverse-mode AD and higher-order primitives. Note that the second part is\n",
    "structured differently. Rather than trying to have a top-to-bottom order that\n",
    "obeys both code dependencies (e.g. data structures need to be defined before\n",
    "they're used) and pedagogical dependencies (concepts need to be introduced\n",
    "before they're implemented) we're going with a single file that can be approached\n",
    "in any order."
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,md:myst,py:light"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
